{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU",
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "32e7669cd82042cbbb419e25db606c1d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_b6698be32bf74c4087e129fab6e13fdd",
              "IPY_MODEL_ff7333b35c1c472482df6550f6e43be2",
              "IPY_MODEL_da4df56a1ba440dbb69087d0019cab1d"
            ],
            "layout": "IPY_MODEL_ad598693c58549e0a83a1328d77b8f83"
          }
        },
        "b6698be32bf74c4087e129fab6e13fdd": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_de2f7a60851f4681877a4c8dccba29cc",
            "placeholder": "​",
            "style": "IPY_MODEL_02b296efbff143f4bfbb904cbc7b1109",
            "value": "Loading checkpoint shards: 100%"
          }
        },
        "ff7333b35c1c472482df6550f6e43be2": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_72ac83e43e2b4d4498070a5b701a5572",
            "max": 3,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_320fa615d4de4652ac34fc2518f7749e",
            "value": 3
          }
        },
        "da4df56a1ba440dbb69087d0019cab1d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_75280ef205a245be92da268e0752dc71",
            "placeholder": "​",
            "style": "IPY_MODEL_3f33eabd6f7f46ef8138abe748d8fbb1",
            "value": " 3/3 [01:06&lt;00:00, 18.14s/it]"
          }
        },
        "ad598693c58549e0a83a1328d77b8f83": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "de2f7a60851f4681877a4c8dccba29cc": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "02b296efbff143f4bfbb904cbc7b1109": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "72ac83e43e2b4d4498070a5b701a5572": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "320fa615d4de4652ac34fc2518f7749e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "75280ef205a245be92da268e0752dc71": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "3f33eabd6f7f46ef8138abe748d8fbb1": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        }
      }
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mi50mprVsU_P"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "from google.colab import userdata\n",
        "\n",
        "# Copy the Hugging Face token from the Colab user-data entry named HF_TOKEN\n",
        "# into the process environment; later cells read it back via os.environ.\n",
        "os.environ[\"HF_TOKEN\"] = userdata.get('HF_TOKEN')"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Install all pinned dependencies with a single pip invocation so the resolver\n",
        "# checks the pins against each other instead of letting a later install\n",
        "# silently replace packages chosen by an earlier one. %pip (rather than !pip3)\n",
        "# installs into the environment of the running kernel.\n",
        "%pip install -q bitsandbytes==0.42.0 peft==0.8.2 trl==0.7.10 accelerate==0.27.1 datasets==2.17.0 transformers==4.38.1"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "-5gJk3W_s0RY",
        "outputId": "ca3d427e-5bfc-4635-f27a-e49e56718f7e"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collecting git+https://****@github.com/huggingface/new-model-addition-golden-gate@add-golden-gate\n",
            "  Cloning https://****@github.com/huggingface/new-model-addition-golden-gate (to revision add-golden-gate) to /tmp/pip-req-build-8jci0sy8\n",
            "  Running command git clone --filter=blob:none --quiet 'https://****@github.com/huggingface/new-model-addition-golden-gate' /tmp/pip-req-build-8jci0sy8\n",
            "  Running command git checkout -b add-golden-gate --track origin/add-golden-gate\n",
            "  Switched to a new branch 'add-golden-gate'\n",
            "  Branch 'add-golden-gate' set up to track remote branch 'add-golden-gate' from 'origin'.\n",
            "  Resolved https://****@github.com/huggingface/new-model-addition-golden-gate to commit e9d36beb5fcafeb2ac327a68eee82009d24cb58f\n",
            "  Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
            "  Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
            "  Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers==4.38.0.dev0) (3.13.1)\n",
            "Requirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /usr/local/lib/python3.10/dist-packages (from transformers==4.38.0.dev0) (0.20.3)\n",
            "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers==4.38.0.dev0) (1.25.2)\n",
            "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers==4.38.0.dev0) (23.2)\n",
            "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers==4.38.0.dev0) (6.0.1)\n",
            "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers==4.38.0.dev0) (2023.12.25)\n",
            "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers==4.38.0.dev0) (2.31.0)\n",
            "Requirement already satisfied: tokenizers<0.19,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers==4.38.0.dev0) (0.15.2)\n",
            "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers==4.38.0.dev0) (0.4.2)\n",
            "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers==4.38.0.dev0) (4.66.2)\n",
            "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.19.3->transformers==4.38.0.dev0) (2023.6.0)\n",
            "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.19.3->transformers==4.38.0.dev0) (4.9.0)\n",
            "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.38.0.dev0) (3.3.2)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.38.0.dev0) (3.6)\n",
            "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.38.0.dev0) (2.0.7)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.38.0.dev0) (2024.2.2)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GemmaTokenizer\n",
        "\n",
        "model_id = \"google/gemma-7b\"\n",
        "\n",
        "# Quantization settings: load the weights in 4-bit NF4 with bfloat16 as the\n",
        "# compute dtype.\n",
        "bnb_config = BitsAndBytesConfig(\n",
        "    bnb_4bit_compute_dtype=torch.bfloat16,\n",
        "    bnb_4bit_quant_type=\"nf4\",\n",
        "    load_in_4bit=True,\n",
        ")\n",
        "\n",
        "# Both downloads authenticate with the token stored in the HF_TOKEN\n",
        "# environment variable (set in the first cell).\n",
        "tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ['HF_TOKEN'])\n",
        "model = AutoModelForCausalLM.from_pretrained(\n",
        "    model_id,\n",
        "    quantization_config=bnb_config,\n",
        "    device_map={\"\": 0},\n",
        "    token=os.environ['HF_TOKEN'],\n",
        ")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 49,
          "referenced_widgets": [
            "32e7669cd82042cbbb419e25db606c1d",
            "b6698be32bf74c4087e129fab6e13fdd",
            "ff7333b35c1c472482df6550f6e43be2",
            "da4df56a1ba440dbb69087d0019cab1d",
            "ad598693c58549e0a83a1328d77b8f83",
            "de2f7a60851f4681877a4c8dccba29cc",
            "02b296efbff143f4bfbb904cbc7b1109",
            "72ac83e43e2b4d4498070a5b701a5572",
            "320fa615d4de4652ac34fc2518f7749e",
            "75280ef205a245be92da268e0752dc71",
            "3f33eabd6f7f46ef8138abe748d8fbb1"
          ]
        },
        "id": "EVEotZX8s-v6",
        "outputId": "e378234f-f56f-483e-c569-f3a196c02370"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "32e7669cd82042cbbb419e25db606c1d"
            }
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Baseline before training: let the loaded model continue a quote prompt.\n",
        "device = \"cuda:0\"\n",
        "text = \"Quote: Imagination is more\"\n",
        "inputs = tokenizer(text, return_tensors=\"pt\").to(device)\n",
        "\n",
        "# Generate at most 20 new tokens and print the decoded text without special tokens.\n",
        "outputs = model.generate(**inputs, max_new_tokens=20)\n",
        "print(tokenizer.decode(outputs[0], skip_special_tokens=True))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "7Msk610TVUGW",
        "outputId": "8c14afe0-dc6e-42b1-d05a-1a7a6c2ace9e"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Quote: Imagination is more important than knowledge. Knowledge is limited. Imagination encircles the world.\n",
            "\n",
            "-Albert Einstein\n",
            "\n",
            "I\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Set the WANDB_DISABLED flag before the trainer is created so the\n",
        "# Weights & Biases logging integration stays off for the training run.\n",
        "os.environ[\"WANDB_DISABLED\"] = \"true\""
      ],
      "metadata": {
        "id": "Mi2P12KsVbyt"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from peft import LoraConfig\n",
        "\n",
        "# LoRA settings: rank-8 adapters on the listed projection modules,\n",
        "# configured for a causal language-modeling task.\n",
        "lora_config = LoraConfig(\n",
        "    task_type=\"CAUSAL_LM\",\n",
        "    r=8,\n",
        "    target_modules=[\n",
        "        \"q_proj\", \"o_proj\", \"k_proj\", \"v_proj\",\n",
        "        \"gate_proj\", \"up_proj\", \"down_proj\",\n",
        "    ],\n",
        ")"
      ],
      "metadata": {
        "id": "7lzjoG3KVRMN"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from datasets import load_dataset\n",
        "\n",
        "# Load the quotes dataset and tokenize its \"quote\" column in batches.\n",
        "data = load_dataset(\"Abirate/english_quotes\").map(\n",
        "    lambda samples: tokenizer(samples[\"quote\"]),\n",
        "    batched=True,\n",
        ")"
      ],
      "metadata": {
        "id": "HPQSpLNAuubn"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import transformers\n",
        "from trl import SFTTrainer\n",
        "\n",
        "def formatting_func(example):\n",
        "    text = f\"Quote: {example['quote'][0]}\\nAuthor: {example['author'][0]}\"\n",
        "    return [text]\n",
        "\n",
        "# Supervised fine-tuning with the LoRA config: 10 optimizer steps, each\n",
        "# accumulating 4 micro-batches of size 1, logging the loss at every step.\n",
        "trainer = SFTTrainer(\n",
        "    model=model,\n",
        "    train_dataset=data[\"train\"],\n",
        "    peft_config=lora_config,\n",
        "    formatting_func=formatting_func,\n",
        "    args=transformers.TrainingArguments(\n",
        "        output_dir=\"outputs\",\n",
        "        max_steps=10,\n",
        "        warmup_steps=2,\n",
        "        per_device_train_batch_size=1,\n",
        "        gradient_accumulation_steps=4,\n",
        "        learning_rate=2e-4,\n",
        "        optim=\"paged_adamw_8bit\",\n",
        "        fp16=True,\n",
        "        logging_steps=1,\n",
        "    ),\n",
        ")\n",
        "trainer.train()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 530
        },
        "id": "HFbR2FIgVfiT",
        "outputId": "ba27fbda-54be-415c-ee47-78632e4ad4c6"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).\n",
            "/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py:223: UserWarning: You didn't pass a `max_seq_length` argument to the SFTTrainer, this will default to 1024\n",
            "  warnings.warn(\n",
            "/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py:290: UserWarning: You passed a tokenizer with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code.\n",
            "  warnings.warn(\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='10' max='10' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [10/10 00:08, Epoch 6/10]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Step</th>\n",
              "      <th>Training Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>1.700500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>0.641000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>1.031500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>4</td>\n",
              "      <td>0.945800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>5</td>\n",
              "      <td>0.516200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>6</td>\n",
              "      <td>1.278600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>7</td>\n",
              "      <td>1.187300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>8</td>\n",
              "      <td>0.339000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>9</td>\n",
              "      <td>0.724500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>10</td>\n",
              "      <td>0.647600</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "TrainOutput(global_step=10, training_loss=0.9011982649564743, metrics={'train_runtime': 10.2202, 'train_samples_per_second': 3.914, 'train_steps_per_second': 0.978, 'total_flos': 5520965345280.0, 'train_loss': 0.9011982649564743, 'epoch': 6.67})"
            ]
          },
          "metadata": {},
          "execution_count": 8
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Re-run generation after training to compare with the earlier completion.\n",
        "device = \"cuda:0\"\n",
        "text = \"Quote: Imagination is\"\n",
        "inputs = tokenizer(text, return_tensors=\"pt\").to(device)\n",
        "\n",
        "# Generate at most 20 new tokens and print the decoded text without special tokens.\n",
        "outputs = model.generate(**inputs, max_new_tokens=20)\n",
        "print(tokenizer.decode(outputs[0], skip_special_tokens=True))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "f5Mim0lNViwe",
        "outputId": "4534ee26-63e3-4ced-ee27-673f0b9d7afb"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Quote: Imagination is more important than knowledge. Knowledge is limited. Imagination encircles the world.\n",
            "\n",
            "Author: Albert Einstein\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "djg3QAMuVx8R"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}
